#  Copyright (c): Wenyi Tang 2017-2019.
#  Author: Wenyi Tang
#  Email: wenyi.tang@intel.com
#  Update Date: 2019 - 3 - 15

from collections import OrderedDict

import torch
import torch.nn as nn


####################
# Basic blocks
####################


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
  # Helper selecting an activation layer.
  # neg_slope: slope for LeakyReLU and the init value for PReLU.
  # n_prelu: num_parameters for PReLU.
  act_type = act_type.lower()
  if act_type == 'relu':
    layer = nn.ReLU(inplace)
  elif act_type == 'leakyrelu':
    layer = nn.LeakyReLU(neg_slope, inplace)
  elif act_type == 'prelu':
    layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
  else:
    raise NotImplementedError('activation layer [%s] is not found' % act_type)
  return layer
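
# Usage sketch (illustrative only; parameter values are example choices):
#   >>> act('relu')                      # nn.ReLU(inplace=True)
#   >>> act('leakyrelu', neg_slope=0.1)  # nn.LeakyReLU(0.1, inplace=True)
#   >>> act('prelu', n_prelu=64)         # one learnable slope per channel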


def norm(norm_type, nc):
  # Helper selecting a normalization layer.
  norm_type = norm_type.lower()
  if norm_type == 'batch':
    layer = nn.BatchNorm2d(nc, affine=True)
  elif norm_type == 'instance':
    layer = nn.InstanceNorm2d(nc, affine=False)
  else:
    raise NotImplementedError(
      'normalization layer [%s] is not found' % norm_type)
  return layer
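
# Usage sketch (illustrative only; `nc` is the number of feature channels):
#   >>> norm('batch', 64)     # nn.BatchNorm2d(64, affine=True)
#   >>> norm('instance', 64)  # nn.InstanceNorm2d(64, affine=False)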


def pad(pad_type, padding):
  # Helper selecting a padding layer.
  # 'zero' padding is handled by the conv layer itself, so this helper is
  # only used for 'reflect' and 'replicate'.
  pad_type = pad_type.lower()
  if padding == 0:
    return None
  if pad_type == 'reflect':
    layer = nn.ReflectionPad2d(padding)
  elif pad_type == 'replicate':
    layer = nn.ReplicationPad2d(padding)
  else:
    raise NotImplementedError(
      'padding layer [%s] is not implemented' % pad_type)
  return layer
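
# Usage sketch (illustrative only):
#   >>> pad('reflect', 1)    # nn.ReflectionPad2d(1)
#   >>> pad('replicate', 0)  # None (nothing to pad)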


def get_valid_padding(kernel_size, dilation):
  kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
  padding = (kernel_size - 1) // 2
  return padding
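
# Example: a 3x3 kernel with dilation 2 has an effective size of
# 3 + (3 - 1) * (2 - 1) = 5, so "same" padding is (5 - 1) // 2 = 2.
#   >>> get_valid_padding(3, 1)  # 1
#   >>> get_valid_padding(3, 2)  # 2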


class ConcatBlock(nn.Module):
  # Concatenate the output of a submodule with its input along the channel dim.
  def __init__(self, submodule):
    super(ConcatBlock, self).__init__()
    self.sub = submodule

  def forward(self, x):
    output = torch.cat((x, self.sub(x)), dim=1)
    return output

  def __repr__(self):
    tmpstr = 'Identity .. \n|'
    modstr = self.sub.__repr__().replace('\n', '\n|')
    tmpstr = tmpstr + modstr
    return tmpstr
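
# Usage sketch (shapes are example values): a 3->16 channel submodule gives
# a 3 + 16 = 19 channel output after concatenation.
#   >>> blk = ConcatBlock(nn.Conv2d(3, 16, 3, padding=1))
#   >>> blk(torch.randn(1, 3, 32, 32)).shape  # torch.Size([1, 19, 32, 32])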


class ShortcutBlock(nn.Module):
  # Element-wise sum of a submodule's output and its input.
  def __init__(self, submodule):
    super(ShortcutBlock, self).__init__()
    self.sub = submodule

  def forward(self, x):
    output = x + self.sub(x)
    return output

  def __repr__(self):
    tmpstr = 'Identity + \n|'
    modstr = self.sub.__repr__().replace('\n', '\n|')
    tmpstr = tmpstr + modstr
    return tmpstr
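
# Usage sketch (example values): the submodule must preserve the input shape
# so that the element-wise sum is valid.
#   >>> blk = ShortcutBlock(nn.Conv2d(64, 64, 3, padding=1))
#   >>> y = blk(torch.randn(1, 64, 8, 8))  # y = x + conv(x)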


def sequential(*args):
  # Flatten modules into a single nn.Sequential, unwrapping nested
  # nn.Sequential and skipping None entries.
  if len(args) == 1:
    if isinstance(args[0], OrderedDict):
      raise NotImplementedError(
        'sequential does not support OrderedDict input.')
    return args[0]  # No sequential is needed.
  modules = []
  for module in args:
    if isinstance(module, nn.Sequential):
      for submodule in module.children():
        modules.append(submodule)
    elif isinstance(module, nn.Module):
      modules.append(module)
  return nn.Sequential(*modules)
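
# Usage sketch: None entries are dropped and nested nn.Sequential is
# flattened into a single module list.
#   >>> m = sequential(nn.Conv2d(3, 8, 3), None, nn.Sequential(nn.ReLU()))
#   >>> list(m.children())  # [Conv2d(3, 8, ...), ReLU()]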


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1,
               bias=True,
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
  """
  Conv layer with padding, normalization, activation
  mode: CNA --> Conv -> Norm -> Act
      NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
  """
  assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [%s]' % mode
  padding = get_valid_padding(kernel_size, dilation)
  p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
  padding = padding if pad_type == 'zero' else 0

  c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, bias=bias, groups=groups)
  a = act(act_type) if act_type else None
  if 'CNA' in mode:
    n = norm(norm_type, out_nc) if norm_type else None
    return sequential(p, c, n, a)
  elif mode == 'NAC':
    if norm_type is None and act_type is not None:
      a = act(act_type, inplace=False)
      # Important!
      # input----ReLU(inplace)----Conv--+----output
      #        |________________________|
      # inplace ReLU will modify the input, therefore wrong output
    n = norm(norm_type, in_nc) if norm_type else None
    return sequential(n, a, p, c)
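
# Usage sketch (channel counts are example values): a Conv-BN-ReLU block
# ('CNA', the default) and a pre-activation block ('NAC').
#   >>> cna = conv_block(64, 64, 3, norm_type='batch', act_type='relu')
#   >>> nac = conv_block(64, 64, 3, norm_type='batch', act_type='relu',
#   ...                  mode='NAC')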


####################
# Useful blocks
####################


class ResNetBlock(nn.Module):
  """
  ResNet Block, 3-3 style
  with extra residual scaling used in EDSR
  (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
  """

  def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1,
               dilation=1, groups=1, bias=True, pad_type='zero',
               norm_type=None, act_type='relu', mode='CNA', res_scale=1):
    super(ResNetBlock, self).__init__()
    conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups,
                       bias, pad_type, norm_type, act_type, mode)
    if mode == 'CNA':
      act_type = None
    if mode == 'CNAC':  # Residual path: |-CNAC-|
      act_type = None
      norm_type = None
    conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups,
                       bias, pad_type, norm_type, act_type, mode)
    # if in_nc != out_nc:
    #     self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
    #         None, None)
    #     print('Need a projecter in ResNetBlock.')
    # else:
    #     self.project = lambda x:x
    self.res = sequential(conv0, conv1)
    self.res_scale = res_scale

  def forward(self, x):
    res = self.res(x).mul(self.res_scale)
    return x + res
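
# Usage sketch (example values): an EDSR-style block, i.e. no normalization
# and residual scaling of 0.1 as used by EDSR for its deeper models.
#   >>> blk = ResNetBlock(64, 64, 64, act_type='relu', res_scale=0.1)
#   >>> y = blk(torch.randn(1, 64, 24, 24))  # same shape as the input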


class ResidualDenseBlock_5C(nn.Module):
  """
  Residual Dense Block
  style: 5 convs
  The core module of the paper (Residual Dense Network for Image Super-Resolution, CVPR 18)
  """

  def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True,
               pad_type='zero', norm_type=None, act_type='leakyrelu',
               mode='CNA'):
    super(ResidualDenseBlock_5C, self).__init__()
    # gc: growth channel, i.e. intermediate channels
    self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias,
                            pad_type=pad_type, norm_type=norm_type,
                            act_type=act_type, mode=mode)
    self.conv2 = conv_block(nc + gc, gc, kernel_size, stride, bias=bias,
                            pad_type=pad_type, norm_type=norm_type,
                            act_type=act_type, mode=mode)
    self.conv3 = conv_block(nc + 2 * gc, gc, kernel_size, stride, bias=bias,
                            pad_type=pad_type, norm_type=norm_type,
                            act_type=act_type, mode=mode)
    self.conv4 = conv_block(nc + 3 * gc, gc, kernel_size, stride, bias=bias,
                            pad_type=pad_type, norm_type=norm_type,
                            act_type=act_type, mode=mode)
    if mode == 'CNA':
      last_act = None
    else:
      last_act = act_type
    self.conv5 = conv_block(nc + 4 * gc, nc, 3, stride, bias=bias,
                            pad_type=pad_type, norm_type=norm_type,
                            act_type=last_act, mode=mode)

  def forward(self, x):
    x1 = self.conv1(x)
    x2 = self.conv2(torch.cat((x, x1), 1))
    x3 = self.conv3(torch.cat((x, x1, x2), 1))
    x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
    x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
    return x5.mul(0.2) + x
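
# Usage sketch: the block preserves shape; each conv sees the input plus all
# previous intermediate features (dense connections), and the output is
# residually scaled by 0.2.
#   >>> rdb = ResidualDenseBlock_5C(64, gc=32)
#   >>> rdb(torch.randn(1, 64, 16, 16)).shape  # torch.Size([1, 64, 16, 16])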


class RRDB(nn.Module):
  """
  Residual in Residual Dense Block
  """

  def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True,
               pad_type='zero', norm_type=None, act_type='leakyrelu',
               mode='CNA'):
    super(RRDB, self).__init__()
    self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias,
                                      pad_type, norm_type, act_type, mode)
    self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias,
                                      pad_type, norm_type, act_type, mode)
    self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias,
                                      pad_type, norm_type, act_type, mode)

  def forward(self, x):
    out = self.RDB1(x)
    out = self.RDB2(out)
    out = self.RDB3(out)
    return out.mul(0.2) + x
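
# Usage sketch: three chained dense blocks with an outer residual connection,
# also scaled by 0.2 (the RRDB used in ESRGAN-style generators).
#   >>> rrdb = RRDB(64, gc=32)
#   >>> y = rrdb(torch.randn(1, 64, 16, 16))  # same shape as the input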


####################
# Upsampler
####################


def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1,
                       bias=True,
                       pad_type='zero', norm_type=None, act_type='relu'):
  """
  Pixel shuffle layer
  (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
  Neural Network, CVPR17)
  """
  conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride,
                    bias=bias,
                    pad_type=pad_type, norm_type=None, act_type=None)
  pixel_shuffle = nn.PixelShuffle(upscale_factor)

  n = norm(norm_type, out_nc) if norm_type else None
  a = act(act_type) if act_type else None
  return sequential(conv, pixel_shuffle, n, a)
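
# Usage sketch (example values): a x2 sub-pixel upsampler; the conv expands
# channels by upscale_factor**2 and PixelShuffle rearranges them into space.
#   >>> up = pixelshuffle_block(64, 64, upscale_factor=2)
#   >>> up(torch.randn(1, 64, 16, 16)).shape  # torch.Size([1, 64, 32, 32])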


def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1,
                 bias=True,
                 pad_type='zero', norm_type=None, act_type='relu',
                 mode='nearest'):
  # Upsample (interpolation) followed by a conv; avoids the checkerboard
  # artifacts of transposed convolution, see
  # https://distill.pub/2016/deconv-checkerboard/
  upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
  conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                    pad_type=pad_type, norm_type=norm_type, act_type=act_type)
  return sequential(upsample, conv)
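
# Usage sketch (example values): nearest-neighbor upsampling followed by a
# conv, the checkerboard-free alternative described at the link above.
#   >>> up = upconv_blcok(64, 64, upscale_factor=2)
#   >>> up(torch.randn(1, 64, 16, 16)).shape  # torch.Size([1, 64, 32, 32])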
